Score Matching for Density Estimation

A first post using Quarto

Optimization
Score Matching
Julia
Score matching is a method for estimating the density of a distribution. Instead of fitting the density directly, it fits the score function, the gradient of the log-density, and recovers the density from it. In this post, I will explain the score matching method, show how it can be used to estimate the density of a distribution, and discuss some of its limitations.
Author

Simon Ghyselincks

Published

May 22, 2024

Big Idea

Denoising Autoencoders (DAE) are neural networks trained to reconstruct the input data from a noisy version of that input, by minimizing the reconstruction error between the clean input and the network's output. A DAE can be used to learn the score function of a distribution: it has been shown that, with Gaussian corruption, training a DAE amounts to learning the score of the kernel density estimate built from the original, noise-free sample points, with the noise distribution as the kernel.

Learning the DAE parameters is therefore equivalent to learning the score function of the kernel density estimate (KDE). The score function is an operator defined as:

\[ s(f(x;\theta)) = \nabla_x \log f(x;\theta) \]

where \(f(x;\theta)\) is the density function of the distribution and \(\theta\) are its parameters.
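The DAE and KDE connection can be checked numerically without training anything. When the clean point is drawn uniformly from a finite set of samples and corrupted with Gaussian noise of variance \(\sigma^2\), the best possible denoiser under squared error is the posterior mean \(E[x \mid \tilde{x}]\), and Tweedie's formula (a standard identity, not stated in the original post) gives \(E[x \mid \tilde{x}] = \tilde{x} + \sigma^2 \nabla_{\tilde{x}} \log q_\sigma(\tilde{x})\), where \(q_\sigma\) is the Gaussian KDE of the clean points. The sketch below uses made-up toy points and an assumed \(\sigma = 0.4\), and compares the posterior mean with \(\tilde{x}\) plus \(\sigma^2\) times a finite-difference estimate of the KDE score.

```julia
# Hypothetical clean "training" points, one per row (2-D toy data)
X = [-1.0 -1.0; 1.0 1.0; 1.0 -1.0; 0.5 -1.2]
σ = 0.4   # assumed Gaussian noise standard deviation

# log q_σ(x) for the Gaussian KDE q_σ(x) = (1/n) Σ_i N(x; x_i, σ² I), via log-sum-exp
function log_kde(x, X, σ)
    n, d = size(X)
    l = [-sum(abs2, x .- X[i, :]) / (2 * σ^2) for i in 1:n]
    m = maximum(l)
    return m + log(sum(exp.(l .- m))) - log(n) - (d / 2) * log(2π * σ^2)
end

# Score of the KDE by central finite differences
function fd_score(x, X, σ; h = 1e-5)
    g = similar(x)
    for k in eachindex(x)
        e = zeros(length(x)); e[k] = h
        g[k] = (log_kde(x .+ e, X, σ) - log_kde(x .- e, X, σ)) / (2h)
    end
    return g
end

# MSE-optimal denoiser when the clean point is a uniformly chosen row of X:
# the posterior mean Σ_i w_i x_i with w_i ∝ exp(-|x̃ - x_i|² / 2σ²)
function posterior_mean(xt, X, σ)
    l = [-sum(abs2, xt .- X[i, :]) / (2 * σ^2) for i in 1:size(X, 1)]
    w = exp.(l .- maximum(l))
    w ./= sum(w)
    return vec(sum(w .* X, dims = 1))
end

xt = [0.3, -0.4]                           # a noisy input x̃
denoised = posterior_mean(xt, X, σ)        # what an ideal DAE would output
tweedie = xt .+ σ^2 .* fd_score(xt, X, σ)  # x̃ + σ² ∇ log q_σ(x̃)
println(denoised, "  ", tweedie)           # agree up to finite-difference error
```

A trained DAE only approximates this ideal denoiser, so in practice the match is approximate as well.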

Because the score determines the log-density up to an additive constant, a learned score can in principle be integrated back to recover the density, with the constant fixed by normalization. This is the idea behind score matching: the density of a distribution is estimated indirectly, by making the model's score \(\nabla_x \log f(x;\theta)\) match the score of the data distribution as closely as possible.
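In one dimension this inversion is just an integral, \(\log f(x) = \log f(x_0) + \int_{x_0}^{x} s(t)\,dt\), and the unknown constant is fixed by requiring \(f\) to integrate to one. The sketch below assumes the score is known exactly (here the standard normal score \(s(x) = -x\), chosen for illustration) and recovers the density on a grid with the trapezoidal rule. In higher dimensions the same idea needs a line integral and a much harder normalization, which is one reason score-based methods often work with the score directly.

```julia
# Assumed known score: for the standard normal, d/dx log φ(x) = -x
score_normal(x) = -x

grid = range(-6, 6; length = 2001)
dx = step(grid)

# Integrate the score with the trapezoidal rule: log f on the grid, up to an additive constant
logf = zeros(length(grid))
for k in 2:length(grid)
    logf[k] = logf[k-1] + dx * (score_normal(grid[k-1]) + score_normal(grid[k])) / 2
end

# Exponentiate (shifted by the maximum to avoid overflow) and fix the constant by normalizing
f = exp.(logf .- maximum(logf))
f ./= sum(f) * dx

truth = exp.(-abs2.(grid) ./ 2) ./ sqrt(2π)
println(maximum(abs.(f .- truth)))   # tiny: integration plus normalization recovers φ
```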

Another benefit of learning the score function of a distribution is that it can be used to move from less probable regions of the distribution to more probable ones by gradient ascent on the log-density. This is useful for generative models, where we want to produce new samples that lie in high-probability regions of the distribution.
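A minimal sketch of that idea, on a toy one-dimensional two-component Gaussian mixture chosen for illustration (not the distribution used later in the post): with an analytic score, plain gradient ascent \(x \leftarrow x + \eta\, s(x)\) with a fixed step size \(\eta\) moves each starting point uphill to the mode of its basin. Practical samplers usually add noise to such updates so that they sample the distribution rather than collapse onto its modes; this sketch omits that.

```julia
# Toy 1-D mixture 0.3·N(-2, 0.5²) + 0.7·N(2, 0.5²), chosen only for illustration
mix_w  = [0.3, 0.7]
mix_mu = [-2.0, 2.0]
mix_sd = [0.5, 0.5]

# Analytic score: responsibility-weighted average of the component scores -(x - μ_k)/σ_k²
function score1d(x)
    l = log.(mix_w) .- log.(mix_sd) .- (x .- mix_mu) .^ 2 ./ (2 .* mix_sd .^ 2)
    r = exp.(l .- maximum(l))
    r ./= sum(r)
    return sum(r .* (mix_mu .- x) ./ mix_sd .^ 2)
end

# Gradient ascent on log p with a fixed step size η: x ← x + η s(x)
function ascend(x0; η = 0.05, steps = 200)
    x = x0
    for _ in 1:steps
        x += η * score1d(x)
    end
    return x
end

println(ascend(-0.5))   # climbs to the mode near -2
println(ascend(0.5))    # climbs to the mode near 2
```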

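However, one of the challenges with score matching is that the score function is not always well-defined, especially in regions of low probability: where the density is zero, or numerically indistinguishable from zero, its logarithm is \(-\infty\) and has no usable gradient, and with few samples in such regions the score is also poorly constrained by the data. This can make it difficult to learn the score function accurately in these regions.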

This post explores some of those limitations and how increasing the bandwidth of the noise kernel in the DAE can help to stabilize the score function in regions of low probability.

An Example of Score Matching

Suppose our ground truth is a distribution in 2D space that is a mixture of three Gaussians, with weights 0.2, 0.2 and 0.6. We can plot this pdf and, later on, its gradient field.

Show the code
using Plots, Distributions

# Define the ground truth distribution
function p(x, y)
    mu1, mu2, mu3 = [-1, -1], [1, 1], [1, -1]
    sigma1, sigma2, sigma3 = [0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0; 0 0.5]

    return 0.2 * pdf(MvNormal(mu1, sigma1), [x, y]) + 0.2 * pdf(MvNormal(mu2, sigma2), [x, y]) + 0.6 * pdf(MvNormal(mu3, sigma3), [x, y])
end

# Plot the distribution using a heatmap
heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x)",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

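Sampling from the distribution can be done by rejection sampling: we propose points uniformly on the square \([-3, 3]^2\) and accept each one with probability \(p(x, y)\) until we have 200 points.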

Show the code
using Plots, Distributions

# Define the ground truth distribution
function p(x, y)
    mu1, mu2, mu3 = [-1, -1], [1, 1], [1, -1]
    sigma1, sigma2, sigma3 = [0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0; 0 0.5]

    return 0.2 * pdf(MvNormal(mu1, sigma1), [x, y]) + 0.2 * pdf(MvNormal(mu2, sigma2), [x, y]) + 0.6 * pdf(MvNormal(mu3, sigma3), [x, y])
end

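# Sample 200 points from the ground truth distribution by rejection sampling:
# propose uniformly on [-3, 3]^2 and accept with probability p(x, y),
# which is valid because the mixture density never exceeds 1
n_points = 200
points = Tuple{Float64, Float64}[]

while length(points) < n_points
    x = rand() * 6 - 3
    y = rand() * 6 - 3
    if rand() < p(x, y)
        push!(points, (x, y))
    end
end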


# Scatter plot of the sampled points
scatter([x for (x, y) in points], [y for (x, y) in points], label="Sampled Points", color=:red, ms=2,
     xlims=(-3, 3), ylims=(-3, 3),
     xticks=[-3, 3], yticks=[-3, 3])
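The loop above is rejection sampling with a uniform proposal. Accepting a point with probability \(p(x, y)\) is valid only because the mixture density stays well below 1, and since the uniform proposal has density 1/36 on \([-3, 3]^2\), roughly 36 proposals are needed per accepted point (about 7,200 for 200 samples). The accepted points follow \(p\) restricted to the square, which loses only a small amount of mass here. The generic sketch below uses a hypothetical helper `rejection_sample` with an explicit bound `M` to make that assumption visible, and it reports the acceptance rate on a 1-D standard normal.

```julia
using Random

# Rejection sampling with a uniform proposal on the box [lo, hi]^d.
# Assumes the target density f satisfies f(x) ≤ M on the box; each proposal is accepted
# with probability f(x) / M. The post's sampler is the case d = 2, box [-3, 3]^2, M = 1.
function rejection_sample(f, n, d; lo = -3.0, hi = 3.0, M = 1.0, rng = Random.default_rng())
    samples = Vector{Vector{Float64}}()
    proposals = 0
    while length(samples) < n
        x = lo .+ (hi - lo) .* rand(rng, d)
        proposals += 1
        if rand(rng) * M < f(x)
            push!(samples, x)
        end
    end
    return samples, proposals
end

# Example target: 1-D standard normal (peak 1/sqrt(2π) ≈ 0.40, so M = 1 is a valid bound)
std_normal(x) = exp(-x[1]^2 / 2) / sqrt(2π)
samples, proposals = rejection_sample(std_normal, 2000, 1; rng = MersenneTwister(1))

println(sum(s[1] for s in samples) / length(samples))  # sample mean, close to 0
println(length(samples) / proposals)                   # acceptance rate, close to 0.997 / 6
```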

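With the sample in hand, we can visualize some of the key concepts of score matching, starting with the score of the ground truth itself. The score function is the gradient of the log-density function. For a single Gaussian the density is proportional to \(\exp(-E(x))\) with energy \(E(x) = \tfrac{1}{2}(x-\mu)^\top \Sigma^{-1}(x-\mu)\), so the normalizing constant drops out of the gradient and the score is simply \(-\nabla_x E(x) = -\Sigma^{-1}(x-\mu)\). For the mixture, the code below computes the gradient of \(\log p\) by automatic differentiation with ForwardDiff.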

Show the code
using Plots, Distributions, ForwardDiff

# Define the ground truth distribution
function p(x, y)
    mu1, mu2, mu3 = [-1, -1], [1, 1], [1, -1]
    sigma1, sigma2, sigma3 = [0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0; 0 0.5]

    return 0.2 * pdf(MvNormal(mu1, sigma1), [x, y]) + 0.2 * pdf(MvNormal(mu2, sigma2), [x, y]) + 0.6 * pdf(MvNormal(mu3, sigma3), [x, y])
end

# Define the log of the distribution
function log_p(x, y)
    val = p(x, y)
    return val > 0 ? log(val) : -Inf
end

# Function to compute the gradient using ForwardDiff
function gradient_log_p(u, v)
    grad = ForwardDiff.gradient(x -> log_p(x[1], x[2]), [u, v])
    return grad[1], grad[2]
end

# Evaluate the score on a coarse grid (x in the outer loop, y in the inner loop)
xs = -3:0.5:3
ys = -3:0.5:3

U = Float64[]
V = Float64[]
for x in xs
    for y in ys
        u, v = gradient_log_p(x, y)
        push!(U, u)
        push!(V, v)
    end
end

# Plot the distribution using a heatmap
heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x) with score",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Flatten the gradients and positions for quiver plot
xxs_flat = [x for x in xs for y in ys]
yys_flat = [y for x in xs for y in ys]

# Plot the vector field
quiver!(xxs_flat, yys_flat, quiver=(vec(U)/20, vec(V)/20), color=:green, quiverkeyscale=0.5)
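For the mixture the score still has a closed form, \(\nabla_x \log p(x) = \sum_k r_k(x)\,\big(-\Sigma_k^{-1}(x - \mu_k)\big)\), where \(r_k(x)\) is the posterior probability (responsibility) of component \(k\) at \(x\). The sketch below evaluates this formula for the mixture parameters used above, with responsibilities computed via log-sum-exp, and checks it against a central finite difference of the log-density. It is an alternative to the ForwardDiff approach and uses only the LinearAlgebra standard library.

```julia
using LinearAlgebra

# Mixture parameters copied from the ground-truth p(x, y) above
gt_weights = [0.2, 0.2, 0.6]
gt_means   = [[-1.0, -1.0], [1.0, 1.0], [1.0, -1.0]]
gt_covs    = [[0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0.0; 0.0 0.5]]

# log(w_k N(x; μ_k, Σ_k)) for every component k
component_logs(x) = [log(w) - 0.5 * dot(x - μ, Σ \ (x - μ)) - 0.5 * logdet(2π * Σ)
                     for (w, μ, Σ) in zip(gt_weights, gt_means, gt_covs)]

# Log-density of the mixture via log-sum-exp
function log_mixture(x)
    l = component_logs(x)
    m = maximum(l)
    return m + log(sum(exp.(l .- m)))
end

# Analytic score: Σ_k r_k(x) (-Σ_k⁻¹ (x - μ_k)), responsibilities = softmax of component_logs
function mixture_score2d(x)
    l = component_logs(x)
    r = exp.(l .- maximum(l))
    r ./= sum(r)
    return sum(r[k] * (-(gt_covs[k] \ (x - gt_means[k]))) for k in eachindex(r))
end

# Compare with a central finite difference of the log-density at one point
x0 = [0.3, -0.7]
h = 1e-6
fd = [(log_mixture(x0 + h * e) - log_mixture(x0 - h * e)) / (2h) for e in ([1.0, 0.0], [0.0, 1.0])]
println(mixture_score2d(x0), "  ", fd)
```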

Next we place a Gaussian kernel with bandwidth 0.3 on each sample point to create the kernel density estimate:

Show the code
using Plots, Distributions, KernelDensity

# Convert points to x and y vectors
x_points = [x for (x, y) in points]
y_points = [y for (x, y) in points]

# Perform kernel density estimation using KernelDensity.jl.
# kde stores density[i, j] at (parzen.x[i], parzen.y[j]); heatmap expects the transpose.
parzen = kde((x_points, y_points); boundary=((-3,3),(-3,3)), bandwidth = (.3,.3))

# Plot the ground truth PDF
p1 = heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x)",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Scatter plot of the sampled points on top of the ground truth PDF
scatter!(p1, x_points, y_points, label="Sampled Points", color=:red, ms=2)


# Plot the Parzen density estimate
p2 = heatmap(
    parzen.x, parzen.y, parzen.density',
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Kernel Density Estimate",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Scatter plot of the sampled points on top of the Parzen density estimate
scatter!(p2, x_points,  y_points, label="Sampled Points", color=:red, ms=2)

# Arrange the plots side by side
plot(p1, p2, layout = @layout([a b]), size=(800, 400))
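KernelDensity.jl evaluates the estimate on a grid, but the quantity it approximates is simply an average of Gaussian bumps centred on the samples. The sketch below evaluates that definition directly for a hypothetical four-point sample, using the same bandwidth of 0.3 in each direction. It then checks with a Riemann sum that the estimate integrates to about one over a wide window.

```julia
# Direct (non-FFT) evaluation of a 2-D Gaussian KDE with diagonal bandwidth (hx, hy):
# f̂(x, y) = (1/n) Σ_i N(x; x_i, hx²) N(y; y_i, hy²)
function kde2d(x, y, xs, ys; hx = 0.3, hy = 0.3)
    s = 0.0
    for (xi, yi) in zip(xs, ys)
        s += exp(-(abs2(x - xi) / (2 * hx^2) + abs2(y - yi) / (2 * hy^2)))
    end
    return s / (length(xs) * 2π * hx * hy)
end

# Hypothetical toy sample; in the post these would be x_points and y_points
px = [-1.0, 0.8, 1.2, 1.0]
py = [-1.1, 0.9, -0.8, -1.3]

# Riemann-sum check that the estimate integrates to about 1 over a wide window
g = range(-6, 6; length = 241)
total = sum(kde2d(x, y, px, py) for x in g, y in g) * step(g)^2
println(total)
```

A direct sum costs one kernel evaluation per sample per query point, which is why grid-based implementations are preferred for plotting dense heatmaps.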

Looking at the density estimate across many bandwidths, we can see the effect of adding more and more noise to the original sampled points on the density estimate that we are learning. As the bandwidth grows the estimate flattens out, and in the limit of very large bandwidths it approaches a uniform density over any fixed window.

Show the code
using Plots, Distributions, KernelDensity
# Define the range of bandwidths for the animation
bandwidths = [(0.01 + 0.05 * i, 0.01 + 0.05 * i) for i in 0:40]

# Create the animation
anim = @animate for bw in bandwidths
    kde_result = kde((x_points,y_points); boundary=((-6, 6), (-6, 6)), bandwidth=bw)

    p2 = heatmap(
        kde_result.x, kde_result.y, kde_result.density',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Kernel Density Estimate, Bandwidth = $(round(bw[1],digits=2))",
        xlims=(-6, 6), ylims=(-6, 6),
        xticks=[-6, 6], yticks=[-6, 6]
    )

    scatter!(p2, x_points, y_points, label="Sampled Points", color=:red, ms=2)
end

# Save the animation as a GIF
gif(anim, "parzen_density_animation_with_gradients.gif", fps=2,show_msg = false)
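One way to quantify this flattening: a Gaussian KDE with bandwidth \(h\) is an equal-weight mixture of \(N(x_i, h^2)\) bumps, so its variance is the (uncorrected) sample variance plus \(h^2\). Once \(h\) is much larger than the spread of the data, the estimate is essentially a single broad bump of width about \(h\). The sketch below checks the variance formula numerically on a made-up 1-D sample.

```julia
using Statistics

# Hypothetical 1-D sample
data = [-1.2, -0.9, 0.1, 0.8, 1.1, 1.4]

# 1-D Gaussian KDE, evaluated directly
kde1d(x, data, h) = sum(exp(-abs2(x - xi) / (2 * h^2)) for xi in data) / (length(data) * h * sqrt(2π))

# Variance of the KDE by numerical integration on a wide grid
function kde_variance(data, h; grid = range(-40, 40; length = 16001))
    dx = step(grid)
    f = [kde1d(x, data, h) for x in grid]
    μ = sum(grid .* f) * dx
    return sum(abs2.(grid .- μ) .* f) * dx
end

for h in (0.1, 1.0, 5.0)
    predicted = var(data; corrected = false) + h^2
    println("h = $h: numerical ≈ $(round(kde_variance(data, h); digits = 4)), predicted = $(round(predicted; digits = 4))")
end
```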

Now we can compute the score of the kernel density estimate to see how it changes with the bandwidth. The score is numerically unstable in regions of sparse data. Recall that the score is the gradient of the log-density function: as the density approaches zero, its logarithm tends to negative infinity. Far from every sample, a Gaussian KDE is zero to within floating-point precision (direct evaluation underflows to exactly 0.0), so taking the log gives \(-\infty\) or values dominated by round-off in sparse, low-probability regions, and finite differences of those values are meaningless. A higher bandwidth spreads each sample's probability mass over a wider area, smoothing both the discrete sample and, in effect, the true distribution it was drawn from (the expected KDE is the true density convolved with the kernel). This keeps the density representable further from the data and so extends the region of numerical stability.

Show the code
using Plots, Distributions, KernelDensity, ForwardDiff

# Define the range of bandwidths for the animation
bandwidths = [(0.01 + 0.05 * i, 0.01 + 0.05 * i) for i in 0:30]

boundary = (-10, 10)
# Create the animation
anim = @animate for bw in bandwidths
    kde_result = kde((x_points, y_points); boundary=(boundary, boundary), bandwidth=bw)

    # Compute log-density
    log_density = log.(kde_result.density)

    # Compute gradients of log-density
    grad_x = zeros(size(log_density))
    grad_y = zeros(size(log_density))

    # Compute gradients using finite difference centered difference
    for i in 2:size(log_density, 1)-1
        for j in 2:size(log_density, 2)-1
            grad_x[i, j] = (log_density[i+1, j] - log_density[i-1, j]) / (kde_result.x[i+1] - kde_result.x[i-1])
            grad_y[i, j] = (log_density[i, j+1] - log_density[i, j-1]) / (kde_result.y[j+1] - kde_result.y[j-1])
        end
    end

    # Plot heatmaps of the gradients
    p1 = heatmap(
        kde_result.x, kde_result.y, grad_x',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Gradient of Log-Density wrt x \n Bandwidth = $(round(bw[1],digits=2))",
        xlims=boundary, ylims=boundary
    )

    # Overlay the scatter plot of the sampled points
    scatter!(p1, x_points, y_points, label="Sampled Points", color=:red, ms=2)

    p2 = heatmap(
        kde_result.x, kde_result.y, grad_y',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Gradient of Log-Density wrt y \n Bandwidth = $(round(bw[1],digits=2))",
        xlims=boundary, ylims=boundary
    )

    # Overlay the scatter plot of the sampled points
    scatter!(p2, x_points, y_points, label="Sampled Points", color=:red, ms=2)

    plot(p1, p2, layout = @layout([a b]), size=(800, 400))
end
# Save the animation as a GIF
gif(anim, "parzen_density_gradient_animation_with_gradients.gif", fps=2, show_msg=false)
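The underflow is easy to reproduce in one dimension. In the sketch below (made-up sample points in \([-3, 3]\), query point \(x = 9\)), the naively computed KDE density at bandwidth 0.1 is exactly 0.0 in Float64, its log is -Inf, and a central difference of two -Inf values is NaN, while at larger bandwidths the same computation returns finite values. The last function is the standard log-sum-exp rewrite, which is not what the post uses: it keeps the log-density finite even for small bandwidths, although far from the data the resulting score is still determined by essentially a single sample point.

```julia
# Hypothetical 1-D sample concentrated in [-3, 3]
data = [-2.1, -0.4, 0.3, 1.7, 2.5]

# Naive KDE: sum the kernels first, take the log afterwards
kde_density(x, h) = sum(exp(-abs2(x - xi) / (2 * h^2)) for xi in data) / (length(data) * h * sqrt(2π))

# Central-difference score of the log-density, mirroring the grid-based computation
naive_score(x, h; δ = 1e-3) = (log(kde_density(x + δ, h)) - log(kde_density(x - δ, h))) / (2δ)

x_far = 9.0
for h in (0.1, 0.5, 2.0)
    println("h = $h: density = $(kde_density(x_far, h)), score ≈ $(naive_score(x_far, h))")
end

# Log-sum-exp form of the same log-density: stays finite far from the data (not used in the post)
function log_kde_stable(x, h)
    l = [-abs2(x - xi) / (2 * h^2) for xi in data]
    m = maximum(l)
    return m + log(sum(exp.(l .- m))) - log(length(data) * h * sqrt(2π))
end
```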

Finally, we overlay the score of the kernel density estimate on the ground truth distribution that it models, starting with the larger bandwidths and moving to the smaller ones. We can see that the region of numerical stability is extended for the larger bandwidths, so a point chosen at random in the sample space is more likely to receive a finite score that points toward the data.

Show the code
# Define the range of bandwidths for the animation
bandwidths = [(0.01 + 0.2 * i, 0.01 + 0.2 * i) for i in 0:10]
bandwidths = reverse(bandwidths)

boundary = (-10, 10)
# Create the animation
anim = @animate for bw in bandwidths
    kde_result = kde((x_points, y_points); boundary=(boundary, boundary), bandwidth=bw)

    # Compute log-density
    log_density = log.(kde_result.density)

    # Compute gradients of log-density
    grad_x = zeros(size(log_density))
    grad_y = zeros(size(log_density))

    # Compute gradients using finite difference centered difference
    for i in 2:size(log_density, 1)-1
        for j in 2:size(log_density, 2)-1
            grad_x[i, j] = (log_density[i+1, j] - log_density[i-1, j]) / (kde_result.x[i+1] - kde_result.x[i-1])
            grad_y[i, j] = (log_density[i, j+1] - log_density[i, j-1]) / (kde_result.y[j+1] - kde_result.y[j-1])
        end
    end
    # Downsample the gradients and coordinates by selecting every 20th point
    downsample_indices_x = 1:20:size(grad_x, 1)
    downsample_indices_y = 1:20:size(grad_y, 2)

    grad_x_downsampled = grad_x[downsample_indices_x, downsample_indices_y]
    grad_y_downsampled = grad_y[downsample_indices_x, downsample_indices_y]

    x_downsampled = kde_result.x[downsample_indices_x]
    y_downsampled = kde_result.y[downsample_indices_y]

    # Flatten positions in the same column-major order as grad[:] (x varies fastest)
    xxs_flat = repeat(x_downsampled, outer=[length(y_downsampled)])
    yys_flat = repeat(y_downsampled, inner=[length(x_downsampled)])

    grad_x_flat = grad_x_downsampled[:]
    grad_y_flat = grad_y_downsampled[:]

    # Plot the actual distribution
    x_range = boundary[1]:0.01:boundary[2]
    y_range = boundary[1]:0.01:boundary[2]
    p1 = heatmap(
        x_range, y_range, p,
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Ground Truth PDF q(x)\n with score of Kernel Density Estimate, \n Bandwidth = $(round(bw[1],digits=2))",
        xlims=boundary, ylims=boundary,
        size=(800, 800)
    )

    # Plot a quiver plot of the downsampled gradients at their matching (x, y) positions
    quiver!(xxs_flat, yys_flat, quiver=(grad_x_flat/10, grad_y_flat/10),
    color=:green, quiverkeyscale=0.5, aspect_ratio=:equal)
end
# Save the animation as a GIF
gif(anim, "kde_score_over_ground_truth_animation.gif", fps=2, show_msg=false)
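The arrows in this animation approximate the KDE score, which for a Gaussian kernel has the closed form \(\nabla_x \log \hat{f}(x) = \sum_i w_i(x)\,(x_i - x)/h^2\), a pull toward a weighted average of the samples with weights \(w_i(x) \propto \exp(-\lVert x - x_i \rVert^2 / 2h^2)\). The sketch below uses a hypothetical sample of 50 points, a far-away starting point (8, 7), and a naive computation of the weights. It shows both sides of the trade-off described above: at \(h = 0.1\) every weight underflows and the score is NaN, while at \(h = 2\) the score is finite and gradient ascent (with step size \(\eta h^2\), an assumed choice) carries the point into the data.

```julia
using Random

# Hypothetical 2-D sample (one point per row) standing in for the post's samples
rng = MersenneTwister(0)
pts = randn(rng, 50, 2) .* 0.7 .+ [1.0 -1.0]

# KDE score Σ_i w_i (x_i - x) / h² with naively computed weights w_i ∝ exp(-|x - x_i|² / 2h²)
function kde_score_naive(x, pts, h)
    k = [exp(-sum(abs2, x .- pts[i, :]) / (2 * h^2)) for i in 1:size(pts, 1)]
    w = k ./ sum(k)   # 0/0 = NaN once every kernel has underflowed
    return vec(sum(w .* (pts .- x'), dims = 1)) ./ h^2
end

# Gradient ascent on the KDE log-density with step size η h²
function kde_ascent(x0, pts, h; η = 0.1, steps = 500)
    x = copy(x0)
    for _ in 1:steps
        x .+= η * h^2 .* kde_score_naive(x, pts, h)
    end
    return x
end

x_start = [8.0, 7.0]
println(kde_score_naive(x_start, pts, 0.1))   # [NaN, NaN]: every kernel underflows
println(kde_score_naive(x_start, pts, 2.0))   # finite, pointing back toward the data
println(kde_ascent(x_start, pts, 2.0))        # ends up inside the cloud of samples
```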